{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import cv2\n",
    "import numpy as np\n",
    "from torch.utils.data import Dataset\n",
    "from torchvision import transforms\n",
    "from albumentations import Compose, RandomCrop, Normalize, HorizontalFlip, Resize\n",
    "from albumentations.pytorch import ToTensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# An example pipeline that uses torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchvisionDataset(Dataset):\n",
    "    \"\"\"Image classification dataset that loads images with PIL.\n",
    "\n",
    "    Args:\n",
    "        file_paths: Sequence of paths to image files.\n",
    "        labels: Sequence of labels; ``labels[i]`` belongs to ``file_paths[i]``.\n",
    "        transform: Optional callable (e.g. a torchvision ``Compose``) applied\n",
    "            to each PIL image.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, file_paths, labels, transform=None):\n",
    "        self.file_paths = file_paths\n",
    "        self.labels = labels\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.file_paths)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(image, label)`` for the sample at ``idx``.\"\"\"\n",
    "        label = self.labels[idx]\n",
    "        file_path = self.file_paths[idx]\n",
    "\n",
    "        # Read an image with PIL. Image.open is lazy and keeps the file open,\n",
    "        # so load it inside a context manager; convert('RGB') guarantees three\n",
    "        # channels so that grayscale, palette, or RGBA files also work with a\n",
    "        # 3-channel mean/std normalization.\n",
    "        with Image.open(file_path) as img:\n",
    "            image = img.convert('RGB')\n",
    "        if self.transform:\n",
    "            image = self.transform(image)\n",
    "        return image, label\n",
    "\n",
    "\n",
    "# torchvision pipeline: resize to 256x256, take a random 224x224 crop,\n",
    "# flip horizontally with probability 0.5, convert the PIL image to a float\n",
    "# tensor in [0, 1], then normalize each channel with the mean/std values\n",
    "# that torchvision's pretrained models expect.\n",
    "torchvision_transform = transforms.Compose([\n",
    "    transforms.Resize(size=(256, 256)),\n",
    "    transforms.RandomCrop(size=224),\n",
    "    transforms.RandomHorizontalFlip(p=0.5),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(\n",
    "        mean=[0.485, 0.456, 0.406],\n",
    "        std=[0.229, 0.224, 0.225],\n",
    "    ),\n",
    "])\n",
    "\n",
    "\n",
    "torchvision_dataset = TorchvisionDataset(\n",
    "    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],\n",
    "    labels=[1, 2, 3],\n",
    "    transform=torchvision_transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The same pipeline with albumentations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AlbumentationsDataset(Dataset):\n",
    "    \"\"\"Image classification dataset that loads images with OpenCV.\n",
    "\n",
    "    __init__ and __len__ are the same as in TorchvisionDataset; only image\n",
    "    loading and the way the transform is called differ.\n",
    "\n",
    "    Args:\n",
    "        file_paths: Sequence of paths to image files.\n",
    "        labels: Sequence of labels; ``labels[i]`` belongs to ``file_paths[i]``.\n",
    "        transform: Optional albumentations ``Compose``; it is called as\n",
    "            ``transform(image=...)`` and its result is read via ``['image']``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, file_paths, labels, transform=None):\n",
    "        self.file_paths = file_paths\n",
    "        self.labels = labels\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.file_paths)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(image, label)`` for the sample at ``idx``.\n",
    "\n",
    "        Raises:\n",
    "            OSError: if OpenCV cannot read ``file_paths[idx]``.\n",
    "        \"\"\"\n",
    "        label = self.labels[idx]\n",
    "        file_path = self.file_paths[idx]\n",
    "\n",
    "        # Read an image with OpenCV. cv2.imread does not raise on a missing or\n",
    "        # unreadable file; it returns None, which would otherwise fail later\n",
    "        # with a confusing error inside cv2.cvtColor.\n",
    "        image = cv2.imread(file_path)\n",
    "        if image is None:\n",
    "            raise OSError(f'Could not read image: {file_path}')\n",
    "\n",
    "        # By default OpenCV uses BGR color space for color images,\n",
    "        # so we need to convert the image to RGB color space.\n",
    "        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "        if self.transform:\n",
    "            augmented = self.transform(image=image)\n",
    "            image = augmented['image']\n",
    "        return image, label\n",
    "\n",
    "\n",
    "# The same steps as the torchvision pipeline. albumentations' Normalize scales\n",
    "# pixels by max_pixel_value (255 by default) before applying mean/std, which\n",
    "# matches torchvision's ToTensor() + Normalize(); ToTensor() from\n",
    "# albumentations.pytorch then turns the result into a PyTorch tensor.\n",
    "albumentations_transform = Compose([\n",
    "    Resize(height=256, width=256),\n",
    "    RandomCrop(height=224, width=224),\n",
    "    HorizontalFlip(p=0.5),\n",
    "    Normalize(\n",
    "        mean=[0.485, 0.456, 0.406],\n",
    "        std=[0.229, 0.224, 0.225],\n",
    "    ),\n",
    "    ToTensor(),\n",
    "])\n",
    "\n",
    "\n",
    "albumentations_dataset = AlbumentationsDataset(\n",
    "    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],\n",
    "    labels=[1, 2, 3],\n",
    "    transform=albumentations_transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using albumentations with PIL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use PIL instead of OpenCV while working with albumentations. Because albumentations transforms operate on numpy arrays, you need to convert the PIL image to a numpy array before applying the transformations, and then convert the augmented numpy array back to a PIL image.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AlbumentationsPilDataset(Dataset):\n",
    "    \"\"\"Dataset that loads images with PIL and augments them with albumentations.\n",
    "\n",
    "    __init__ and __len__ are the same as in TorchvisionDataset. Items are\n",
    "    returned as PIL images, so the transform should output an image array that\n",
    "    Image.fromarray accepts (e.g. uint8 RGB).\n",
    "\n",
    "    Args:\n",
    "        file_paths: Sequence of paths to image files.\n",
    "        labels: Sequence of labels; ``labels[i]`` belongs to ``file_paths[i]``.\n",
    "        transform: Optional albumentations ``Compose``; it is called as\n",
    "            ``transform(image=...)`` and its result is read via ``['image']``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, file_paths, labels, transform=None):\n",
    "        self.file_paths = file_paths\n",
    "        self.labels = labels\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.file_paths)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(image, label)`` for the sample at ``idx``; ``image`` is a PIL image.\"\"\"\n",
    "        label = self.labels[idx]\n",
    "        file_path = self.file_paths[idx]\n",
    "\n",
    "        # Load eagerly and close the file (Image.open is lazy). convert('RGB')\n",
    "        # yields a consistent H x W x 3 array below; palette or grayscale files\n",
    "        # would otherwise become 2-D arrays after np.array.\n",
    "        with Image.open(file_path) as img:\n",
    "            image = img.convert('RGB')\n",
    "\n",
    "        if self.transform:\n",
    "            # Convert PIL image to numpy array\n",
    "            image_np = np.array(image)\n",
    "            # Apply transformations\n",
    "            augmented = self.transform(image=image_np)\n",
    "            # Convert the augmented numpy array back to a PIL Image\n",
    "            image = Image.fromarray(augmented['image'])\n",
    "        return image, label\n",
    "\n",
    "\n",
    "# Only resizing, cropping and flipping: the dataset turns the result back into\n",
    "# a PIL image with Image.fromarray, so the array must stay uint8 RGB and\n",
    "# Normalize/ToTensor are deliberately left out of this pipeline.\n",
    "albumentations_pil_transform = Compose([\n",
    "    Resize(height=256, width=256),\n",
    "    RandomCrop(height=224, width=224),\n",
    "    HorizontalFlip(p=0.5),\n",
    "])\n",
    "\n",
    "\n",
    "# Note that this dataset will output PIL images, not numpy arrays or PyTorch tensors\n",
    "albumentations_pil_dataset = AlbumentationsPilDataset(\n",
    "    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],\n",
    "    labels=[1, 2, 3],\n",
    "    transform=albumentations_pil_transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# albumentations equivalents for torchvision transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
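    "Some pairs in this table are close counterparts rather than exact equivalents. For example, `ColorJitter` randomly changes brightness, contrast, saturation, and hue, while `HueSaturationValue` shifts hue, saturation, and value. Similarly, `Pad` always pads by a fixed amount, while `PadIfNeeded` pads only up to the given minimum height and width.\n",
    "\n",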
    "| torchvision transform \t| albumentations transform \t| albumentations example \t|\n",
    "|---------------------------------------------------------------------------------------------------------------------------------\t|---------------------------------------------------------------------------------------------------------------------------------------------------------\t|---------------------------------------------------------------------------------------------\t|\n",
    "| [Compose](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Compose) \t| [Compose](https://albumentations.readthedocs.io/en/latest/api/core.html#albumentations.core.composition.Compose) \t| ```Compose([Resize(256, 256), RandomCrop(224, 224)])``` \t|\n",
    "| [CenterCrop](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.CenterCrop) \t| [CenterCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.CenterCrop) \t| ```CenterCrop(256, 256)``` \t|\n",
    "| [ColorJitter](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.ColorJitter) \t| [HueSaturationValue](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.HueSaturationValue) \t| ```HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5)``` \t|\n",
    "| [Pad](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Pad) \t| [PadIfNeeded](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.PadIfNeeded) \t| ```PadIfNeeded(min_height=512, min_width=512)``` \t|\n",
    "| [RandomAffine](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomAffine) \t| [ShiftScaleRotate](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ShiftScaleRotate) \t| ```ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=0.5)``` \t|\n",
    "| [RandomCrop](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomCrop) \t| [RandomCrop](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomCrop) \t| ```RandomCrop(256, 256)``` \t|\n",
    "| [RandomGrayscale](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomGrayscale) \t| [ToGray](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ToGray) \t| ```ToGray(p=0.5)``` \t|\n",
    "| [RandomHorizontalFlip](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomHorizontalFlip) \t| [HorizontalFlip](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.HorizontalFlip) \t| ```HorizontalFlip(p=0.5)``` \t|\n",
    "| [RandomRotation](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomRotation) \t| [Rotate](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Rotate) \t| ```Rotate(limit=45, p=0.5)``` \t|\n",
    "| [RandomVerticalFlip](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomVerticalFlip) \t| [VerticalFlip](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.VerticalFlip) \t| ```VerticalFlip(p=0.5)``` \t|\n",
    "| [Resize](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Resize) \t| [Resize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Resize) \t| ```Resize(256, 256)``` \t|\n",
    "| [Normalize](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Normalize) \t| [Normalize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Normalize) \t| ```Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])``` \t|"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
